import json
import argparse
import os
from collections import defaultdict
from sklearn.metrics import cohen_kappa_score

def compute_iou(boxA, boxB):
    """
    Return the Intersection over Union of two boxes.

    Both boxes are indexable as COCO-style ``[x, y, w, h]`` (top-left corner
    plus width and height). The result is ``0.0`` when the boxes do not
    overlap or when the union area is not positive.
    """
    # Edges of the overlapping rectangle (if any).
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[0] + boxA[2], boxB[0] + boxB[2])
    bottom = min(boxA[1] + boxA[3], boxB[1] + boxB[3])

    # Clamp each extent at zero so disjoint boxes yield an empty overlap.
    overlap = max(0, right - left) * max(0, bottom - top)
    union = boxA[2] * boxA[3] + boxB[2] * boxB[3] - overlap

    if union > 0:
        return overlap / union
    return 0.0


def load_annotations(source):
    """
    Load annotations from a COCO-style JSON file or a directory of such files.

    Args:
        source: Path to a single JSON file, or to a directory. For a
            directory, every regular file directly inside it whose name ends
            in ``.json`` (case-insensitive) is loaded; subdirectories are
            not searched.

    Returns:
        A list of annotation dicts gathered from each file's top-level
        ``annotations`` key. A file where that key is missing or null
        contributes no annotations. Directory files are read in sorted
        name order so the result is reproducible across runs.
    """
    if os.path.isdir(source):
        # os.listdir() returns entries in arbitrary order; sort so that the
        # merged annotation list (and any order-dependent tie-breaking done
        # by callers) is deterministic.
        paths = [
            os.path.join(source, fname)
            for fname in sorted(os.listdir(source))
            if fname.lower().endswith('.json')
        ]
        # A subdirectory may happen to be named "*.json"; it cannot be opened.
        paths = [path for path in paths if os.path.isfile(path)]
    else:
        paths = [source]

    anns = []
    for path in paths:
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        # "or []" also covers an explicit `"annotations": null`.
        anns.extend(data.get('annotations') or [])
    return anns


def _attribute_label(ann, key, default=None):
    """Return ``ann['attributes'][key]`` as a stripped string label."""
    # "or {}" tolerates an annotation whose "attributes" is missing or null.
    attributes = ann.get('attributes') or {}
    return str(attributes.get(key, default)).strip()


def _assign_one_to_one(gts, usrs):
    """
    Greedily pair ground-truth boxes with user boxes, highest IoU first.

    Returns a dict mapping a ground-truth index to ``(user_index, iou)``.
    Every ground-truth box and every user box appears in at most one pair,
    and only pairs with IoU > 0 are considered.
    """
    candidates = []
    for gt_idx, gt in enumerate(gts):
        for usr_idx, usr in enumerate(usrs):
            iou = compute_iou(gt['bbox'], usr['bbox'])
            if iou > 0:
                candidates.append((iou, gt_idx, usr_idx))
    # list.sort is stable, so equal-IoU candidates keep their
    # (ground-truth order, user order) ranking and ties resolve the same
    # way on every run.
    candidates.sort(key=lambda candidate: -candidate[0])

    assignment = {}
    taken_usrs = set()
    for iou, gt_idx, usr_idx in candidates:
        if gt_idx in assignment or usr_idx in taken_usrs:
            continue
        assignment[gt_idx] = (usr_idx, iou)
        taken_usrs.add(usr_idx)
    return assignment


def evaluate(gt_annos, user_annos, iou_threshold=0.5):
    """
    Evaluate IoU and attribute accuracy, returning both detection and attribute data.

    Annotations are grouped by ``image_id``. Within each image, ground-truth
    and user boxes are paired one-to-one (greedy, highest IoU first), so a
    single user box can never be counted as a match for several ground-truth
    boxes and ``matched`` can never exceed the number of user boxes. A pair
    counts as matched when its IoU is >= ``iou_threshold``; attributes are
    compared only for matched pairs.

    Args:
        gt_annos: iterable of ground-truth annotation dicts; each needs
            ``image_id`` and ``bbox`` and may carry ``attributes``.
        user_annos: iterable of user annotation dicts of the same form.
        iou_threshold: minimum IoU for a pair to count as matched.

    Returns:
      avg_iou_all: float, mean IoU over all GT boxes (0.0 for a GT box
        with no paired user box)
      avg_iou_matched: float, mean IoU over matched pairs only
      pose_pairs: list of (gt_pose, usr_pose), lower-cased strings
      vis_pairs: list of (gt_vis, usr_vis), strings
      matched: int
      total_gt: int
      pose_acc_by_class: dict mapping GT pose -> fraction labelled equally
      vis_acc_by_class: dict mapping GT visibility -> fraction labelled equally
    """
    gt_by_img = defaultdict(list)
    usr_by_img = defaultdict(list)
    for ann in gt_annos:
        gt_by_img[ann['image_id']].append(ann)
    for ann in user_annos:
        usr_by_img[ann['image_id']].append(ann)

    all_ious = []
    matched_ious = []
    pose_pairs = []
    vis_pairs = []
    pose_count = defaultdict(int)
    pose_correct = defaultdict(int)
    vis_count = defaultdict(int)
    vis_correct = defaultdict(int)

    for image_id, gts in gt_by_img.items():
        usrs = usr_by_img.get(image_id, [])
        assignment = _assign_one_to_one(gts, usrs)

        for gt_idx, gt in enumerate(gts):
            usr_idx, iou = assignment.get(gt_idx, (None, 0.0))
            all_ious.append(iou)
            if usr_idx is None or iou < iou_threshold:
                continue
            usr = usrs[usr_idx]
            matched_ious.append(iou)

            # Pose: compared case-insensitively; a missing pose becomes ''.
            gt_pose = _attribute_label(gt, 'pose', '').lower()
            usr_pose = _attribute_label(usr, 'pose', '').lower()
            pose_pairs.append((gt_pose, usr_pose))
            pose_count[gt_pose] += 1
            if gt_pose == usr_pose:
                pose_correct[gt_pose] += 1

            # Visibility: a missing level is kept as the label 'None'.
            gt_vis = _attribute_label(gt, 'visibility_level')
            usr_vis = _attribute_label(usr, 'visibility_level')
            vis_pairs.append((gt_vis, usr_vis))
            vis_count[gt_vis] += 1
            if gt_vis == usr_vis:
                vis_correct[gt_vis] += 1

    matched = len(matched_ious)
    total_gt = len(all_ious)
    avg_iou_all = sum(all_ious) / total_gt if total_gt else 0.0
    avg_iou_matched = sum(matched_ious) / matched if matched else 0.0
    # Every key in *_count was incremented at least once, so cnt >= 1.
    pose_acc_by_class = {pose: pose_correct[pose] / cnt
                         for pose, cnt in pose_count.items()}
    vis_acc_by_class = {vis: vis_correct[vis] / cnt
                        for vis, cnt in vis_count.items()}

    return (
        avg_iou_all,
        avg_iou_matched,
        pose_pairs,
        vis_pairs,
        matched,
        total_gt,
        pose_acc_by_class,
        vis_acc_by_class
    )


def main(gt_path, user_source, iou_threshold=0.5):
    """
    Load ground-truth and user annotations, evaluate them, and print a report.

    The report contains detection counts (TP/FN/FP), precision/recall/F1,
    percent agreement and Cohen's kappa for the pose and visibility
    attributes of matched boxes, and average IoU figures. FP is derived as
    the number of user annotations minus the number of matches.

    Args:
        gt_path: path passed to ``load_annotations`` for the ground truth.
        user_source: path passed to ``load_annotations`` for the user data.
        iou_threshold: IoU threshold forwarded to ``evaluate``.
    """

    def _agreement(pairs):
        # Fraction of pairs whose two labels are equal; 0.0 when empty.
        if not pairs:
            return 0.0
        same = [truth for truth, guess in pairs if truth == guess]
        return len(same) / len(pairs)

    def _kappa(pairs):
        # Cohen's kappa over the paired labels; None when there are no pairs.
        if not pairs:
            return None
        truths, guesses = zip(*pairs)
        return cohen_kappa_score(truths, guesses)

    reference = load_annotations(gt_path)
    submitted = load_annotations(user_source)
    user_count = len(submitted)

    results = evaluate(reference, submitted, iou_threshold)
    # The per-class accuracy dicts (last two items) are not reported here.
    avg_all, avg_matched, pose_pairs, vis_pairs, matched, total = results[:6]

    # Detection metrics
    tp = matched
    fn = total - matched
    fp = user_count - matched
    predicted = tp + fp
    precision = tp / predicted if predicted else 0.0
    recall = tp / total if total else 0.0
    pr_sum = precision + recall
    f1 = (2 * precision * recall / pr_sum) if pr_sum else 0.0

    # Attribute agreement, computed in full before anything is printed.
    percent_pose = _agreement(pose_pairs)
    percent_vis = _agreement(vis_pairs)
    kappa_pose = _kappa(pose_pairs)
    kappa_vis = _kappa(vis_pairs)

    report = [
        f"TP: {tp}",
        f"FN: {fn}",
        f"FP: {fp}",
        "\nDetection Metrics:",
        f"  Precision: {precision:.4f}",
        f"  Recall:    {recall:.4f}",
        f"  F1 Score:  {f1:.4f}",
        f"\nPercent Agreement (Pose):       {percent_pose:.4f}",
        f"Percent Agreement (Visibility): {percent_vis:.4f}",
    ]
    if kappa_pose is not None:
        report.append(f"\nCohen's κ (Pose):       {kappa_pose:.4f}")
    if kappa_vis is not None:
        report.append(f"Cohen's κ (Visibility): {kappa_vis:.4f}")
    report += [
        f"\nTotal GT boxes evaluated: {total}",
        f"User annotated boxes:     {user_count}",
        f"Matched (IoU >= {iou_threshold}): {matched}",
        f"No Matched (IoU < {iou_threshold}): {fn}",
        f"\nAverage IoU (all):           {avg_all:.4f}",
        f"Average IoU (matched only): {avg_matched:.4f}",
    ]

    for line in report:
        print(line)

if __name__ == '__main__':
    # Command-line entry point: parse the three options and run the report.
    cli = argparse.ArgumentParser(
        description='Evaluate annotation IoU and attribute accuracy with counts',
    )
    option_table = (
        ('--gt', {'default': 'agreements.json',
                  'help': 'Path to ground truth JSON or folder'}),
        ('--user', {'default': 'annotations',
                    'help': 'Path to user annotation JSON file or directory'}),
        ('--iou-th', {'type': float, 'default': 0.5,
                      'help': 'IoU threshold for matching'}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    options = cli.parse_args()
    # argparse exposes "--iou-th" as the attribute "iou_th".
    main(options.gt, options.user, options.iou_th)

